# Работа с данными ----
# Parameters ----
# Fixed seed so the randomly sampled augmentation parameters are reproducible.
set.seed(1)
# Candidate rotation angles: 0, 10, ..., 360.
angles <- seq(0, 360, 10)
# Candidate shift values (shared by column and row shifts): -10, -8, ..., 10.
shifts <- seq(-10, 10, 2)
# Number of source patches read from ./patches/patch<i>.jpg.
num_of_patches <- 168
# Number of augmented copies generated for each patch.
num_of_replicas <- 50
# Output folder for augmented images. Create it only when it is missing so a
# re-run of the script does not emit an "already exists" warning.
if (!dir.exists("./augmented_data")) {
  dir.create("./augmented_data")
}
# Helper: write one augmented copy of an image ----
#
# Runs Augmentation() (presumably OpenImageR) on `image` and writes the result
# with writeImage() (presumably EBImage) as a JPEG named
# "<file>_<ind>_<angle>_<shift_cols>_<shift_rows>.jpg" inside `out_dir`.
# Every call applies a horizontal flip, the given column/row shift, a bilinear
# rotation by `angle`, and ZCA whitening (30 components, epsilon 0.1).
#
# Arguments:
#   file       - prefix for the output file name (no extension).
#   image      - image to augment.
#   ind        - replica index; only used in the output file name.
#   angle      - value passed as `rotate_angle`.
#   shift_cols - value passed as `shift_cols`.
#   shift_rows - value passed as `shift_rows`.
#   out_dir    - output folder; defaults to "./augmented_data" as before.
#
# `ind`, `angle`, `shift_cols` and `shift_rows` must be whole numbers because
# they are formatted with "%i".
# Returns whatever writeImage() returns.
#
# NOTE(review): flip_mode = "horizontal" is applied to every replica, so no
# un-flipped copies are produced - confirm this is intended.
write_aug_image <- function(file, image, ind, angle, shift_cols, shift_rows,
                            out_dir = "./augmented_data") {
  image_augmented <- Augmentation(
    image,
    flip_mode = "horizontal",
    shift_cols = shift_cols, shift_rows = shift_rows,
    rotate_angle = angle, rotate_method = "bilinear",
    zca_comps = 30, zca_epsilon = 0.1,
    threads = 1, verbose = FALSE
  )
  out_name <- sprintf(
    "%s_%i_%i_%i_%i.jpg", file, ind, angle, shift_cols, shift_rows
  )
  writeImage(image_augmented, file.path(out_dir, out_name))
}
# Generate augmented images in ./augmented_data/ ----
# For each source patch, draw `num_of_replicas` random
# (angle, shift_cols, shift_rows) triples and write one augmented copy per
# triple via write_aug_image().
for (i in seq_len(num_of_patches)) {
  # Zero-padded prefix (patch001, patch002, ...) so that the alphabetical
  # order returned by list.files() follows the numeric patch order.
  # NOTE: three digits keep this ordering only for up to 999 patches.
  outfile <- sprintf("patch%03d", i)
  file <- sprintf("patch%d.jpg", i)
  image <- readImage(sprintf("./patches/%s", file))
  # Draws happen in the same order as before (angle, shift_cols, shift_rows),
  # so the seeded RNG produces the same parameters.
  augment_param <- data.frame(
    angle = sample(angles, num_of_replicas, replace = TRUE),
    shift_cols = sample(shifts, num_of_replicas, replace = TRUE),
    shift_rows = sample(shifts, num_of_replicas, replace = TRUE),
    ind = seq_len(num_of_replicas)
  )
  # Plain row loop instead of apply(): apply() coerces the data frame to a
  # matrix and builds a result list that is never used.
  for (r in seq_len(nrow(augment_param))) {
    write_aug_image(
      outfile, image,
      ind = augment_param$ind[r],
      angle = augment_param$angle[r],
      shift_cols = augment_param$shift_cols[r],
      shift_rows = augment_param$shift_rows[r]
    )
  }
}
# Build the image feature matrix ----
# One row per augmented image, one column per pixel of a 61 x 61 image.
features <- 61 * 61
dataset.size <- num_of_replicas * num_of_patches
# Only the generated JPEGs are read. list.files() returns them sorted by name,
# so rows end up grouped by the zero-padded patch prefix (patch001_*, ...).
aug_files <- list.files("./augmented_data/", pattern = "\\.jpg$")
# Fail early if the folder is incomplete or still holds files from a run with
# other settings. Otherwise the loop would index past the last row, or leave
# trailing rows silently filled with zeros.
if (length(aug_files) != dataset.size) {
  stop(
    sprintf(
      "Expected %d augmented images in ./augmented_data/, found %d",
      dataset.size, length(aug_files)
    ),
    call. = FALSE
  )
}
nn.data.x <- matrix(0, nrow = dataset.size, ncol = features)
for (index in seq_along(aug_files)) {
  nn.data.x[index, ] <- as.numeric(
    readImage(sprintf("./augmented_data/%s", aug_files[index]))
  )
}
# Build the label vector ----
# Assumes patch_labels.csv holds one class label (1, 2 or 3) per source
# patch, listed in patch order (patch1, patch2, ...) - TODO confirm.
patch_labels <- read.csv("patch_labels.csv", header = FALSE)
# unlist() flattens the labels whether the file is one column or one row.
patch_label_vec <- unlist(patch_labels, use.names = FALSE)
stopifnot(
  length(patch_label_vec) == num_of_patches,
  all(patch_label_vec %in% 1:3)
)
# The rows of nn.data.x are grouped by patch (num_of_replicas consecutive rows
# per patch, in patch order), so each patch's own label is repeated once per
# replica. The previous rep(1, ...), rep(2, ...), rep(3, ...) construction
# was only correct if the CSV happened to be sorted by class.
nn.data.y <- as.numeric(rep(patch_label_vec, each = num_of_replicas))
# Архитектура нейронной сети ----
# Convert the flat feature matrices (one row per image) into 4-D arrays of
# 61 x 61 pixels x 1 channel x N samples, one sample per original row.
# NOTE(review): train.x / test.x (and train.y / test.y used later) are not
# created in this section; presumably a train/test split of nn.data.x /
# nn.data.y happens elsewhere - confirm.
train.array <- local({
  arr <- t(train.x)
  dim(arr) <- c(61, 61, 1, ncol(arr))
  arr
})
test.array <- local({
  arr <- t(test.x)
  dim(arr) <- c(61, 61, 1, ncol(arr))
  arr
})
# Network definition ----
# Two stages of convolution (10 filters, 5x5 kernel) -> tanh -> 2x2 max
# pooling with stride 2, then a fully connected layer with 3 outputs and a
# softmax output layer.

# Input placeholder.
data <- mx.symbol.Variable("data")

# Stage 1.
conv.1 <- mx.symbol.Convolution(
  data = data,
  kernel = c(5, 5),
  num_filter = 10
)
tanh.1 <- mx.symbol.Activation(
  data = conv.1,
  act_type = "tanh"
)
pool.1 <- mx.symbol.Pooling(
  data = tanh.1,
  kernel = c(2, 2),
  stride = c(2, 2),
  pool.type = "max"
)

# Stage 2.
conv.2 <- mx.symbol.Convolution(
  data = pool.1,
  kernel = c(5, 5),
  num_filter = 10
)
tanh.2 <- mx.symbol.Activation(
  data = conv.2,
  act_type = "tanh"
)
pool.2 <- mx.symbol.Pooling(
  data = tanh.2,
  kernel = c(2, 2),
  stride = c(2, 2),
  pool.type = "max"
)

# Classifier head: 3 output units, one per class.
fc.1 <- mx.symbol.FullyConnected(
  data = pool.2,
  num_hidden = 3
)
nn.model <- mx.symbol.SoftmaxOutput(data = fc.1)

# Render the computation graph.
graph.viz(nn.model)
# Training ----
# Seed MXNet's own RNG so training is reproducible.
mx.set.seed(1)
# Class labels are shifted from 1..3 to 0..2 for SoftmaxOutput. The test set
# is passed as eval.data, so a validation accuracy is reported alongside the
# training accuracy after every epoch.
# NOTE(review): the log below shows training accuracy climbing to ~0.91 while
# validation accuracy stays around 0.6, which suggests overfitting.
model <- mx.model.FeedForward.create(
  nn.model,
  X = train.array,
  y = as.array(train.y - 1),
  eval.data = list(
    data = test.array,
    label = as.array(test.y - 1)
  ),
  ctx = mx.cpu(),
  num.round = 40,
  optimizer = "adadelta",
  eval.metric = mx.metric.accuracy,
  epoch.end.callback = mx.callback.log.train.metric(10)
)
## Start training with 1 devices
## [1] Train-accuracy=0.412860576923077
## [1] Validation-accuracy=0.540178571428571
## [2] Train-accuracy=0.524174528301887
## [2] Validation-accuracy=0.607142857142857
## [3] Train-accuracy=0.553803066037736
## [3] Validation-accuracy=0.638392857142857
## [4] Train-accuracy=0.571933962264151
## [4] Validation-accuracy=0.498325892857143
## [5] Train-accuracy=0.589475235849057
## [5] Validation-accuracy=0.551339285714286
## [6] Train-accuracy=0.602152122641509
## [6] Validation-accuracy=0.637276785714286
## [7] Train-accuracy=0.613797169811321
## [7] Validation-accuracy=0.430803571428571
## [8] Train-accuracy=0.616597877358491
## [8] Validation-accuracy=0.61328125
## [9] Train-accuracy=0.626621462264151
## [9] Validation-accuracy=0.633928571428571
## [10] Train-accuracy=0.642246462264151
## [10] Validation-accuracy=0.535714285714286
## [11] Train-accuracy=0.660524764150943
## [11] Validation-accuracy=0.635602678571429
## [12] Train-accuracy=0.660819575471698
## [12] Validation-accuracy=0.680803571428571
## [13] Train-accuracy=0.676591981132076
## [13] Validation-accuracy=0.590401785714286
## [14] Train-accuracy=0.686173349056604
## [14] Validation-accuracy=0.640625
## [15] Train-accuracy=0.693985849056604
## [15] Validation-accuracy=0.631138392857143
## [16] Train-accuracy=0.71373820754717
## [16] Validation-accuracy=0.546875
## [17] Train-accuracy=0.717865566037736
## [17] Validation-accuracy=0.571986607142857
## [18] Train-accuracy=0.739239386792453
## [18] Validation-accuracy=0.64453125
## [19] Train-accuracy=0.751474056603774
## [19] Validation-accuracy=0.505022321428571
## [20] Train-accuracy=0.764298349056604
## [20] Validation-accuracy=0.573102678571429
## [21] Train-accuracy=0.770931603773585
## [21] Validation-accuracy=0.635602678571429
## [22] Train-accuracy=0.773584905660377
## [22] Validation-accuracy=0.6171875
## [23] Train-accuracy=0.802771226415094
## [23] Validation-accuracy=0.616629464285714
## [24] Train-accuracy=0.810731132075472
## [24] Validation-accuracy=0.631696428571429
## [25] Train-accuracy=0.801739386792453
## [25] Validation-accuracy=0.62109375
## [26] Train-accuracy=0.83313679245283
## [26] Validation-accuracy=0.592633928571429
## [27] Train-accuracy=0.829893867924528
## [27] Validation-accuracy=0.580915178571429
## [28] Train-accuracy=0.839475235849057
## [28] Validation-accuracy=0.631138392857143
## [29] Train-accuracy=0.844634433962264
## [29] Validation-accuracy=0.622209821428571
## [30] Train-accuracy=0.847435141509434
## [30] Validation-accuracy=0.638392857142857
## [31] Train-accuracy=0.847877358490566
## [31] Validation-accuracy=0.580357142857143
## [32] Train-accuracy=0.863649764150943
## [32] Validation-accuracy=0.620535714285714
## [33] Train-accuracy=0.872641509433962
## [33] Validation-accuracy=0.629464285714286
## [34] Train-accuracy=0.874410377358491
## [34] Validation-accuracy=0.580357142857143
## [35] Train-accuracy=0.876031839622642
## [35] Validation-accuracy=0.560267857142857
## [36] Train-accuracy=0.87721108490566
## [36] Validation-accuracy=0.614397321428571
## [37] Train-accuracy=0.891362028301887
## [37] Validation-accuracy=0.579241071428571
## [38] Train-accuracy=0.888266509433962
## [38] Validation-accuracy=0.628348214285714
## [39] Train-accuracy=0.906692216981132
## [39] Validation-accuracy=0.6015625
## [40] Train-accuracy=0.906839622641509
## [40] Validation-accuracy=0.613839285714286
# Evaluation ----
# predict() yields class scores with one column per sample. The index of the
# largest score in each column is the predicted class (1-based), which is
# directly comparable with test.y / train.y.
preds_test <- apply(predict(model, test.array), MARGIN = 2, FUN = which.max)
preds_train <- apply(predict(model, train.array), MARGIN = 2, FUN = which.max)
# Accuracy in percent for both splits.
results <- data.frame(
  test_accuracy = sum(preds_test == test.y) / length(test.y) * 100,
  train_accuracy = sum(preds_train == train.y) / length(train.y) * 100
)
results
## test_accuracy train_accuracy
## 1 61.23529 87.0597